/**
 * Copyright (c) 2017 Saleem Rashid
 *
 * Permission is hereby granted, free of charge, to any person obtaining
 * a copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included
 * in all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, E1PRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES
 * OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
 * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
 * OTHER DEALINGS IN THE SOFTWARE.
 */

#include "base32.h"

#include <string.h>

const char* BASE32_ALPHABET_RFC4648 = "ABCDEFGHIJKLMNOPQRSTUVWXYZ23456789";

static inline void base32_5to8(const uint8_t* in, uint8_t length, uint8_t* out);
static inline bool
    base32_8to5(const uint8_t* in, uint8_t length, uint8_t* out, const char* alphabet);
static inline void base32_8to5_raw(const uint8_t* in, uint8_t length, uint8_t* out);

static inline int base32_encode_character(uint8_t decoded, const char* alphabet);
static inline int base32_decode_character(char encoded, const char* alphabet);

char* base32_encode(
    const uint8_t* in,
    size_t inlen,
    char* out,
    size_t outlen,
    const char* alphabet) {
    size_t length = base32_encoded_length(inlen);
    if(outlen <= length) {
        return NULL;
    }

    base32_encode_unsafe(in, inlen, (uint8_t*)out);

    for(size_t i = 0; i < length; i++) {
        int ret = base32_encode_character(out[i], alphabet);

        if(ret == -1) {
            return NULL;
        } else {
            out[i] = ret;
        }
    }

    out[length] = '\0';
    return &out[length];
}

uint8_t*
    base32_decode(const char* in, size_t inlen, uint8_t* out, size_t outlen, const char* alphabet) {
    size_t length = base32_decoded_length(inlen);
    if(outlen < length) {
        return NULL;
    }

    if(!base32_decode_unsafe((uint8_t*)in, inlen, (uint8_t*)out, alphabet)) {
        return NULL;
    }

    return &out[length];
}

void base32_encode_unsafe(const uint8_t* in, size_t inlen, uint8_t* out) {
    uint8_t remainder = inlen % 5;
    size_t limit = inlen - remainder;

    size_t i = 0, j = 0;
    for(i = 0, j = 0; i < limit; i += 5, j += 8) {
        base32_5to8(&in[i], 5, &out[j]);
    }

    if(remainder) base32_5to8(&in[i], remainder, &out[j]);
}

bool base32_decode_unsafe(const uint8_t* in, size_t inlen, uint8_t* out, const char* alphabet) {
    uint8_t remainder = inlen % 8;
    size_t limit = inlen - remainder;

    size_t i = 0, j = 0;
    for(i = 0, j = 0; i < limit; i += 8, j += 5) {
        if(!base32_8to5(&in[i], 8, &out[j], alphabet)) {
            return false;
        }
    }

    if(remainder && !base32_8to5(&in[i], remainder, &out[j], alphabet)) {
        return false;
    }

    return true;
}

size_t base32_encoded_length(size_t inlen) {
    uint8_t remainder = inlen % 5;

    return (inlen / 5) * 8 + (remainder * 8 + 4) / 5;
}

size_t base32_decoded_length(size_t inlen) {
    uint8_t remainder = inlen % 8;

    return (inlen / 8) * 5 + (remainder * 5) / 8;
}

void base32_5to8(const uint8_t* in, uint8_t length, uint8_t* out) {
    if(length >= 1) {
        out[0] = (in[0] >> 3);
        out[1] = (in[0] & 7) << 2;
    }

    if(length >= 2) {
        out[1] |= (in[1] >> 6);
        out[2] = (in[1] >> 1) & 31;
        out[3] = (in[1] & 1) << 4;
    }

    if(length >= 3) {
        out[3] |= (in[2] >> 4);
        out[4] = (in[2] & 15) << 1;
    }

    if(length >= 4) {
        out[4] |= (in[3] >> 7);
        out[5] = (in[3] >> 2) & 31;
        out[6] = (in[3] & 3) << 3;
    }

    if(length >= 5) {
        out[6] |= (in[4] >> 5);
        out[7] = (in[4] & 31);
    }
}

bool base32_8to5(const uint8_t* in, uint8_t length, uint8_t* out, const char* alphabet) {
    if(length == 1 || length == 3 || length == 6 || length > 8) {
        return false;
    }

    if(alphabet) {
        uint8_t decoded[length];
        memset(decoded, 0, sizeof(decoded));

        for(size_t i = 0; i < length; i++) {
            int ret = base32_decode_character(in[i], alphabet);

            if(ret == -1) {
                return false;
            } else {
                decoded[i] = ret;
            }
        }

        base32_8to5_raw(decoded, length, out);
    } else {
        base32_8to5_raw(in, length, out);
    }

    return true;
}

void base32_8to5_raw(const uint8_t* in, uint8_t length, uint8_t* out) {
    if(length >= 2) {
        out[0] = (in[0] << 3);
        out[0] |= (in[1] >> 2);
    }

    if(length >= 4) {
        out[1] = (in[1] & 3) << 6;
        out[1] |= (in[2] << 1);
        out[1] |= (in[3] >> 4);
    }

    if(length >= 5) {
        out[2] = (in[3] & 15) << 4;
        out[2] |= (in[4] >> 1);
    }

    if(length >= 7) {
        out[3] = (in[4] & 1) << 7;
        out[3] |= (in[5] << 2);
        out[3] |= (in[6] >> 3);
    }

    if(length >= 8) {
        out[4] = (in[6] & 7) << 5;
        out[4] |= (in[7] & 31);
    }
}

int base32_encode_character(uint8_t decoded, const char* alphabet) {
    if(decoded >> 5) {
        return -1;
    }

    if(alphabet == BASE32_ALPHABET_RFC4648) {
        if(decoded < 26) {
            return 'A' + decoded;
        } else {
            return '2' - 26 + decoded;
        }
    }

    return alphabet[decoded];
}

int base32_decode_character(char encoded, const char* alphabet) {
    if(alphabet == BASE32_ALPHABET_RFC4648) {
        if(encoded >= 'A' && encoded <= 'Z') {
            return encoded - 'A';
        } else if(encoded >= 'a' && encoded <= 'z') {
            return encoded - 'a';
        } else if(encoded >= '2' && encoded <= '7') {
            return encoded - '2' + 26;
        } else {
            return -1;
        }
    }

    const char* occurrence = strchr(alphabet, encoded);

    if(occurrence) {
        return occurrence - alphabet;
    } else {
        return -1;
    }
}
